/*    Copyright 2010 Tobias Marschall
 *
 *    This file is part of MoSDi.
 *
 *    MoSDi is free software: you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation, either version 3 of the License, or
 *    (at your option) any later version.
 *
 *    MoSDi is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with MoSDi.  If not, see <http://www.gnu.org/licenses/>.
 */

package mosdi.tests;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import junit.framework.TestCase;
import mosdi.discovery.EvaluatedPattern;
import mosdi.discovery.MatchCountSearch;
import mosdi.discovery.MotifFinder;
import mosdi.discovery.ObjectiveFunction;
import mosdi.discovery.SequenceCountThresholdSearch;
import mosdi.discovery.MotifFinder.SearchState;
import mosdi.discovery.objectives.OccurrenceCountObjective;
import mosdi.discovery.strategies.ThresholdSearch;
import mosdi.distributions.PoissonDistribution;
import mosdi.fa.Alphabet;
import mosdi.fa.CDFA;
import mosdi.fa.DFAFactory;
import mosdi.fa.GeneralizedString;
import mosdi.fa.IIDTextModel;
import mosdi.index.SuffixTree;
import mosdi.paa.ClumpSizeCalculator;
import mosdi.util.BitArray;
import mosdi.util.Iupac;
import mosdi.util.IupacStringConstraints;
import mosdi.util.iterators.IupacAbelianPatternGenerator;
import mosdi.util.iterators.IupacPatternGenerator;
import mosdi.util.iterators.IupacPatternIterator;

public class MotifFinderTest extends TestCase {

	private static int countMatches(int[] sequence, BitArray[] generalizedAlphabet, int[] pattern) {
		int matches = 0;
		for (int i=0; i<sequence.length; ++i) {
			int j=0;
			for (; (j<pattern.length) && (i+j<sequence.length); ++j) {
				if (sequence[i+j]==-1) break;
				if (!generalizedAlphabet[pattern[j]].get(sequence[i+j])) break;
			}
			if (j==pattern.length) matches+=1;
		}
		return matches;
	}
	
	private class TestSearch implements MotifFinder.SearchSpecification {
		private int knownPrefixLength;
		private int[] pattern;
		private SearchState searchState;
		private int n;
		@Override
		public void initialize(SearchState searchState) {
			this.searchState = searchState;
			this.knownPrefixLength = 0;
			this.pattern = new int[searchState.getStringLength()];
			Arrays.fill(pattern, -1);
			this.n = 0;
		}
		@Override
		public void updatePattern(int newCharacter, int leftmostChangedPosition) {
			pattern[leftmostChangedPosition] = newCharacter;
			knownPrefixLength = leftmostChangedPosition+1;
		}
		private void assertPrefix(int prefixLength) {
			assertTrue(knownPrefixLength>=prefixLength);
			for (int i=0; i<prefixLength; ++i) {
				assertEquals(pattern[i], searchState.getPattern()[i]);
			}
		}
		@Override
		public boolean check(int prefixLength, int[] nodes) {
			assertPrefix(prefixLength);
			n+=1;
			return n%7 != 0;
		}

		@Override
		public void evaluateCandidate(int[] nodes) {
			assertPrefix(searchState.getStringLength());
		}
		@Override
		public List<EvaluatedPattern> getResults() {
			return null;
		}
	}
	
	public void testSearch() {
		String sequence = "ATGAGTAACTCGAACTTCTCCATCGAGGAACACTTCCCAGATATGTGGGATGCCATCATGCACGATTGGCTTGCCGATAGCTCGTCGGCTAATCCCGATC";
		Alphabet alphabet = Alphabet.getDnaAlphabet();
		int[] intSequence = alphabet.buildIndexArray(sequence, true);
		SuffixTree suffixTree = new SuffixTree(intSequence, alphabet.size());
		TestSearch search = new TestSearch();
		int patternLength = 7;
		int[] maxFreq = {7,1,1,2};
		int[] minFreq = {0,0,0,0};
		IupacStringConstraints constraints = new IupacStringConstraints(minFreq, maxFreq);
		MotifFinder motifFinder = new MotifFinder(suffixTree, Iupac.asGeneralizedAlphabet(), false);
		motifFinder.findIupacPatterns(patternLength, constraints, search);
	}
	
	public void testFindAbelianPatternInstances() {
		String sequence = "ATGAGTAACTCGAACTTCTCCATCGAGGAACACTTCCCAGATATGTGGGATGCCATCATGCACGATTGGCTTGCCGATAGCTCGTCGGCTAATCCCGATC";
		int[] intSequence = new int[sequence.length()+1];
		Alphabet dnaAlphabet = Alphabet.getDnaAlphabet();
		for (int i=0; i<sequence.length(); ++i) {
			intSequence[i]=dnaAlphabet.getIndex(sequence.charAt(i));
		}
		intSequence[sequence.length()]=-1;
		SuffixTree suffixTree = new SuffixTree(intSequence, dnaAlphabet.size());
		int minMatches =  14;
		int length = 4;
		int[] maxFreq = {2,1,1,1};
		int[] minFreq = {0,0,0,0};
		IupacAbelianPatternGenerator apg = new IupacAbelianPatternGenerator(length, minFreq, maxFreq, 1.0);
		Alphabet iupacAlphabet = apg.getAlphabet();
		BitArray[] generalizedAlphabet = Iupac.asGeneralizedAlphabet();
		int[] occurrenceCountAnnotation = suffixTree.calcOccurrenceCountAnnotation();
		MatchCountSearch search = new MatchCountSearch(minMatches, occurrenceCountAnnotation, null);
		Map<String,Integer> map = new HashMap<String,Integer>();
		MotifFinder motifFinder = new MotifFinder(suffixTree, generalizedAlphabet, false);
		for (int[] abelianPattern : apg) {
			motifFinder.findAbelianPatternInstances(abelianPattern, search);
			List<EvaluatedPattern> l = search.getResults();
			for (EvaluatedPattern m : l) {
				String s = iupacAlphabet.buildString(m.getPattern());
				assertFalse(map.containsKey(s));
				map.put(s, m.getScore());
			}
		}
		int n = 0;
		for (int[] s : new IupacPatternGenerator(length, minFreq, maxFreq)) {
			int matches = countMatches(intSequence, generalizedAlphabet, s);
			if (matches>=minMatches) {
				String string = iupacAlphabet.buildString(s);
				assertTrue(map.containsKey(string));
				assertEquals(matches, (int)map.get(string));
				n+=1;
			}
		}
		assertEquals(n,map.size());
	}

	public void testFindAbelianPatternInstancesWithAnnotations() {
		String sequence = "ATGAGTAACTCGAACTTCTCCATCGAGGAACACTTCCCAGATATGTGGGATGCCATCATGCACGATTGGCTTGCCGATAGCTCGTCGGCTAATCCCGATC";
		int[] intSequence = new int[sequence.length()+1];
		Alphabet alphabet = Alphabet.getDnaAlphabet();
		for (int i=0; i<sequence.length(); ++i) {
			intSequence[i]=alphabet.getIndex(sequence.charAt(i));
		}
		intSequence[sequence.length()]=-1;
		SuffixTree suffixTree = new SuffixTree(intSequence, alphabet.size());
		int minMatches =  14;
		int length = 4;
		int[] maxFreq = {2,1,1,1};
		int[] minFreq = {0,0,0,0};
		IupacAbelianPatternGenerator apg = new IupacAbelianPatternGenerator(length, minFreq, maxFreq, 1.0);
		Alphabet iupacAlphabet = apg.getAlphabet();
		BitArray[] generalizedAlphabet = Iupac.asGeneralizedAlphabet();
		int[] occurrenceCountAnnotation = suffixTree.calcOccurrenceCountAnnotation();
		int[] maxMatchCountAnnotation = suffixTree.calcMaxMatchCountAnnotation(length);
		MatchCountSearch search = new MatchCountSearch(minMatches, occurrenceCountAnnotation, maxMatchCountAnnotation);
		Map<String,Integer> map = new HashMap<String,Integer>();
		MotifFinder motifFinder = new MotifFinder(suffixTree, generalizedAlphabet, false);
		for (int[] abelianPattern : apg) {
			motifFinder.findAbelianPatternInstances(abelianPattern, search);
			List<EvaluatedPattern> l = search.getResults();
			// List<SuffixTree.Matches> l = suffixTree.findAbelianPatternInstances(abelianPattern, generalizedAlphabet, minMatches, null);
			for (EvaluatedPattern m : l) {
				String s = iupacAlphabet.buildString(m.getPattern());
				assertFalse(map.containsKey(s));
				map.put(s, m.getScore());
			}
		}
		int n = 0;
		for (int[] s : new IupacPatternGenerator(length, minFreq, maxFreq)) {
			int matches = countMatches(intSequence, generalizedAlphabet, s);
			if (matches>=minMatches) {
				String string = iupacAlphabet.buildString(s);
				assertTrue(map.containsKey(string));
				assertEquals(matches, (int)map.get(string));
				n+=1;
			}
		}
		assertEquals(n,map.size());
	}
	
	private static int countMatchingSequences(int[] sequence, BitArray[] generalizedAlphabet, int[] pattern) {
		int totalMatches = 0;
		int n = 0;
		for (int i=0; i<sequence.length; ++i) {
			if (sequence[i]==-1) {
				n = 0;
				continue;
			}
			if (n>0) continue;
			int j=0;
			for (; (j<pattern.length) && (i+j<sequence.length); ++j) {
				if (sequence[i+j]==-1) break;
				if (!generalizedAlphabet[pattern[j]].get(sequence[i+j])) break;
			}
			if (j==pattern.length) {
				n=1;
				totalMatches+=n;
			}
		}
		return totalMatches;
	}
	
	/** Test motif discovery in multiple sequences. */
	public void testFindAbelianPatternInstancesMulti() {
		String sequence = "ATGAGTAACTCGAACTTCTC$CATCGAGGAACACTTCC$AGATATGTGGGATGCCAT$CATGCACGATTGGCT$TGCCGATAGCTCGTCG$GCTAATCCCGATC$";
		Alphabet sentinelAlphabet = new Alphabet(Arrays.asList('$','A', 'C', 'G', 'T'));
		int[] intSequence = sentinelAlphabet.buildIndexArray(sequence);
		// change '$' to -1
		for (int i=0; i<intSequence.length; ++i) intSequence[i]-=1;
		SuffixTree suffixTree = new SuffixTree(intSequence, 4);
		
		boolean considerReverse = false;
		BitArray[] sequenceOccurrenceAnnotation = suffixTree.calcSequenceOccurrenceAnnotation(considerReverse?2:1);
		BitArray[] generalizedAlphabet = Iupac.asGeneralizedAlphabet();

		int minMatches =  4;
		int length = 4;
		int[] maxFreq = {2,1,1,1};
		int[] minFreq = {0,0,0,0};
		IupacAbelianPatternGenerator apg = new IupacAbelianPatternGenerator(length, minFreq, maxFreq, 1.0);
		Alphabet iupacAlphabet = Alphabet.getIupacAlphabet();
		Map<String,Integer> map = new HashMap<String,Integer>();
		MotifFinder motifFinder = new MotifFinder(suffixTree, generalizedAlphabet, considerReverse);
		for (int[] abelianPattern : apg) {
			SequenceCountThresholdSearch search = new SequenceCountThresholdSearch(minMatches, sequenceOccurrenceAnnotation);
			motifFinder.findAbelianPatternInstances(abelianPattern, search);
			List<EvaluatedPattern> l = search.getResults();
			for (EvaluatedPattern m : l) {
				String s = iupacAlphabet.buildString(m.getPattern());
				assertFalse(map.containsKey(s));
				map.put(s, m.getScore());
			}
		}
		
		int n = 0;
		for (int[] s : new IupacPatternGenerator(length, minFreq, maxFreq)) {
			int matches = countMatchingSequences(intSequence, generalizedAlphabet, s);
			if (matches>=minMatches) {
				String string = iupacAlphabet.buildString(s);
				assertTrue(map.containsKey(string));
				assertEquals(matches, (int)map.get(string));
				n+=1;
			}
		}
		assertEquals(n,map.size());
	}
	
	public void testFindIupacPatterns() {
		String sequence = "ATGAGTAACTCGAACTTCTCCATCGAGGAACACTTCCCAGATATGTGGGATGCCATCATGCACGATTGGCTTGCCGATAGCTCGTCGGCTAATCCCGATC";
		List<String> sequences = new ArrayList<String>();
		sequences.add(sequence);
		int[] intSequence = new int[sequence.length()+1];
		final Alphabet dnaAlphabet = Alphabet.getDnaAlphabet();
		final Alphabet iupacAlphabet = Alphabet.getIupacAlphabet();
		for (int i=0; i<sequence.length(); ++i) {
			intSequence[i]=dnaAlphabet.getIndex(sequence.charAt(i));
		}
		intSequence[sequence.length()]=-1;
		SuffixTree suffixTree = new SuffixTree(intSequence, dnaAlphabet.size());
		int patternLength = 4;
		int[] maxFreq = {2,1,1,1};
		int[] minFreq = {0,0,0,0};
		IupacStringConstraints constraints = new IupacStringConstraints(minFreq, maxFreq);
		// BitArray[] generalizedAlphabet = Iupac.iupacToGeneralizedString(alphabet, "ABCDGHKMNRSTVWY").getPositions();
		int[] occurrenceCountAnnotation = suffixTree.calcOccurrenceCountAnnotation();
		int[] maxMatchCountAnnotation = suffixTree.calcMaxMatchCountAnnotation(patternLength);
		BitArray[] generalizedAlphabet = Iupac.asGeneralizedAlphabet();
		double[] charDist = {0.15, 0.35, 0.3, 0.2};
		double threshold = 1e-3;
		int effectiveLength = 0;
		for (String s : sequences) effectiveLength+=s.length()-patternLength+1;
		ObjectiveFunction objective = new OccurrenceCountObjective(new IIDTextModel(dnaAlphabet.size(), charDist), occurrenceCountAnnotation, maxMatchCountAnnotation, effectiveLength);
		ThresholdSearch search = new ThresholdSearch(objective, threshold);
		Map<String,Double> map = new HashMap<String,Double>();
		MotifFinder motifFinder = new MotifFinder(suffixTree, generalizedAlphabet, false);
		motifFinder.findIupacPatterns(patternLength, constraints, search);
		List<EvaluatedPattern> result = search.getResults();
		for (EvaluatedPattern m : result) {
			String s = iupacAlphabet.buildString(m.getPattern()); 
			map.put(s, m.getMinusLogPValue());
			// System.out.println(String.format("%s %d %e", s, m.getScore(), Math.exp(-m.getMinusLogPValue())));
		}
		int n = 0;
		Iterator<int[]> it = new IupacPatternIterator(patternLength, constraints);
		while (it.hasNext()) {
			int[] s = it.next();
			int matches = countMatches(intSequence, generalizedAlphabet, s);
			List<GeneralizedString> l = new ArrayList<GeneralizedString>();
			l.add(Iupac.toGeneralizedString(iupacAlphabet.buildString(s)));
			CDFA cdfa = DFAFactory.build(dnaAlphabet, l, 50000);
			double singleExpectation = 0.0;
			for (GeneralizedString p : l) singleExpectation+=p.getProbability(charDist);
			ClumpSizeCalculator csc = new ClumpSizeCalculator(new IIDTextModel(dnaAlphabet.size(), charDist), cdfa, s.length);
			double[] clumpSizeDist = csc.clumpSizeDistribution(20, 1e-30);
			double expectedClumpSize = 0.0;
			for (int i=1; i<clumpSizeDist.length; ++i) {
				expectedClumpSize+=clumpSizeDist[i]*i;
			}
			double expectation = singleExpectation * (sequence.length()-patternLength+1);
			double lambda = expectation/expectedClumpSize;
			PoissonDistribution poissonDist = new PoissonDistribution(lambda);
			double pValue = poissonDist.compoundPoissonPValue(clumpSizeDist, matches);
			if (pValue<=threshold) {
				String string = iupacAlphabet.buildString(s);
				// System.out.println(String.format("%s %d %e %e %e", string, matches, pValue, singleExpectation, expectedClumpSize));
				assertTrue(map.containsKey(string));
				assertEquals(-Math.log(pValue), map.get(string), 1e-4);
				n+=1;
			}
		}
		assertEquals(n,map.size());
	}
	
}
